import os
import logging
import argparse
import datetime
import numpy as np
import pandas as pd
from tqdm import tqdm 

# Configuração do logger
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class BalanceModel:
    def __init__(self, drenagem_file, disponibilidade_file, retiradas_file, efluentes_file, criterio, nt, scn_fix, scn_number):
        self.drenagem_file = drenagem_file
        self.disponibilidade_file = disponibilidade_file
        self.retiradas_file = retiradas_file
        self.efluentes_file = efluentes_file
        self.nt = nt
        self.scn_fix = scn_fix
        self.scn_number = scn_number
        self.criterio = criterio
        
        # Placeholders para os DataFrames e matrizes de dados
        self.df_drenagem = None
        self.df_disponibilidade = None
        self.df_retiradas = None        
        self.df_efluentes = None
        self.df_drl = None  # DataFrame resultante da união entre drenagem, disponibilidade e retiradas
        
        self.d_qnat = None  # Matriz de vazões naturais (disponibilidades)
        self.d_qwit = None  # Matriz de retiradas

    def read_data(self):
        try:
            logger.info("Iniciando leitura dos arquivos...")
            
            self.df_drenagem = pd.read_csv(self.drenagem_file, sep=";", low_memory=False, dtype={"CodBas_ID": int, "CodDown_ID": int, "Mini_ID": int})
            self.df_disponibilidade = pd.read_csv(self.disponibilidade_file, sep=";", low_memory=False, dtype=str)
            self.df_retiradas = pd.read_csv(self.retiradas_file, sep=";", low_memory=False, dtype=str).dropna(how="all")
            self.df_efluentes = pd.read_csv(self.efluentes_file, sep=";", low_memory=False, dtype={"cotrecho": int, "Q_Inflow": float})
            
            logger.info("Arquivos carregados com sucesso!")
            
            # Normaliza os valores da coluna 'cotrecho'
            for df in [self.df_drenagem, self.df_disponibilidade, self.df_retiradas, self.df_efluentes]:
                df["cotrecho"] = df["cotrecho"].astype(str).str.strip()
            
            # Processamento do DataFrame de disponibilidades:
            num_disp_cols = self.df_disponibilidade.shape[1] - 1  # Desconta a coluna "cotrecho"
            new_disp_columns = ["cotrecho"] + [f"Q_disp_{i}" for i in range(1, num_disp_cols + 1)]
            self.df_disponibilidade.columns = new_disp_columns
            
            # Realiza merge entre a drenagem e a disponibilidade com base na coluna "cotrecho"
            self.df_drl = self.df_drenagem.merge(self.df_disponibilidade, on="cotrecho", how="left")
            
            # Processamento do DataFrame de retiradas:
            num_wit_cols = self.df_retiradas.shape[1] - 1  # Desconta a coluna "cotrecho"
            new_wit_columns = ["cotrecho"] + [f"Q_wit_{i}" for i in range(1, num_wit_cols + 1)]
            self.df_retiradas.columns = new_wit_columns

            self.df_efluentes = self.df_efluentes[['cotrecho', 'Q_Inflow']]
            new_eff_columns = ["cotrecho", "Q_Inflow"]
            self.df_efluentes.columns = new_eff_columns
            self.df_efluentes.loc[:,"cotrecho"] = self.df_efluentes["cotrecho"].astype(str).str.strip()
            
            # Realiza merge entre o DataFrame atual (drenagem+disponibilidade) e as retiradas com base na coluna "cotrecho"
            self.df_drl = self.df_drl.merge(self.df_retiradas, on="cotrecho", how="left")
            
            #Realiza merge entre o Dataframe com drenagem, disponibilidade e retirada com efluentes
            self.df_drl = self.df_drl.merge(self.df_efluentes, on="cotrecho", how="left")
          
            self.d_codbas = self.df_drl['CodBas_ID'].values
            self.d_codjus = self.df_drl['CodDown_ID'].values
            self.d_ord = self.df_drl['Order_ID'].values
            self.d_mini = self.df_drl['Mini_ID'].values
            self.d_codres = self.df_drl['Reserv_ID'].values
            self.d_length = self.df_drl['Length_km'].values
            self.Res = self.df_drl['Reserv_ID'].values.astype(int)
            
            print(f"Total de cotrechos na drenagem: {len(self.df_drenagem)}")
            print(f"Total de cotrechos na disponibilidade: {len(self.df_disponibilidade)}")
            print(f"Total de cotrechos nas retiradas: {len(self.df_retiradas)}")
            print(f"Total de cotrechos nos efluentes: {len(self.df_efluentes)}")
            
            logger.info("Leitura e união dos dados concluída com sucesso.")
        except Exception as e:
            logger.error(f"Erro ao ler e unir arquivos de dados: {e}")
            raise

    def associate_scenarios(self):
        """
        Utiliza os dados já carregados e unidos (df_drl), onde:
          - As vazões naturais passam a ser as disponibilidades (Q_disp_n)
          - As retiradas estão disponíveis nas colunas renomeadas como Q_wit_n
        """
        try:
            logger.info("Iniciando associação de cenários...")

            # Disponibilidades:
            if self.scn_fix == "str":
                fixed_flow = self.df_drl[f'Q_disp_{self.scn_number}'].values.astype(float).reshape(-1, 1)
                self.d_qnat = np.tile(fixed_flow, (1, self.nt))
            else:
                Qdisp_cols = [f'Q_disp_{i}' for i in range(1, self.nt + 1)]
                self.d_qnat = self.df_drl[Qdisp_cols].astype(float).values

            # Retiradas:
            self.df_wit = self.df_drl  
            if self.scn_fix == "wit":
                fixed_withdrawal = self.df_wit[f'Q_wit_{self.scn_number}'].values.astype(float).reshape(-1, 1)
                self.d_qwit = np.tile(fixed_withdrawal, (1, self.nt))
            else:
                Qwit_cols = [f'Q_wit_{i}' for i in range(1, self.nt + 1)]
                self.d_qwit = self.df_wit[Qwit_cols].astype(float).values

            logger.info("Associação de cenários concluída com sucesso.")
        except Exception as e:
            logger.error(f"Erro ao associar cenários: {e}")
            raise

    def run_model(self):
        try:
            logger.info("Executando modelo de balanço hídrico...")
            model_start_time = datetime.datetime.now()
            logger.info(f"Início do processamento no run_model: {model_start_time}")
            
            n_d = self.df_drl.shape[0]
            nt = self.nt

            # Inicializa as matrizes de fluxo e déficit
            d_qmn   = np.zeros((n_d, nt))
            self.d_qmr   = np.zeros((n_d, nt))
            d_qmd   = np.zeros((n_d, nt))
            self.d_qcat  = np.zeros((n_d, nt))
            self.d_qout   = np.zeros((n_d, nt))
            self.d_qdef   = np.zeros((n_d, nt))
            self.d_qdefacm = np.zeros((n_d, nt))
            self.d_wbal   = np.zeros((n_d, nt))
            self.d_qnatres = np.zeros((n_d, nt))
            self.d_qwitacm = np.zeros ((n_d,nt))            
            self.d_qeff = np.zeros((n_d, nt))

            # Preenche a matriz de efluente com os valores 
            self.d_qeff[:] = np.tile(self.df_drl[f'Q_Inflow'].values.astype(float).reshape(-1, 1), (1, nt))
            self.d_qeff = np.nan_to_num(self.d_qeff, nan=0)

            # Ordena os trechos de acordo com Mini_ID (montante para jusante)
            sorted_indices = np.argsort(self.d_mini)

            # Processa cada trecho
            for idx in tqdm(sorted_indices, desc="Processando trechos", total=len(sorted_indices)):
                idr = idx
                # Se o trecho é de cabeceira, as variáveis de fluxo a montante são zero
                if self.d_ord[idr] == 1:
                    for it in range(nt):
                        d_qmn[idr, it] = 0
                        self.d_qmr[idr, it] = 0
                        d_qmd[idr, it] = 0
                else:
                    # Identifica os trechos a montante (onde CodDown_ID equivale ao CodBas_ID deste trecho)
                    ind_jus = np.where(self.d_codjus == self.d_codbas[idr])[0]
                    up_count = ind_jus.size
                    for it in range(nt):
                        if up_count in [2, 3, 4, 5]:
                            d_qmn[idr, it]  = np.sum(self.d_qnat[ind_jus, it])
                            self.d_qmr[idr, it]  = np.sum(self.d_qout[ind_jus, it]) if np.any(self.d_qout[ind_jus, it]) else 0
                            d_qmd[idr, it]  = np.sum(self.d_qdefacm[ind_jus, it]) if np.any(self.d_qdefacm[ind_jus, it]) else 0
                        else:
                            d_qmn[idr, it]  = np.sum(self.d_qnat[ind_jus, it])
                            self.d_qmr[idr, it]  = np.sum(self.d_qout[ind_jus, it]) if np.any(self.d_qout[ind_jus, it]) else 0
                            d_qmd[idr, it]  = np.sum(self.d_qdefacm[ind_jus, it]) if np.any(self.d_qdefacm[ind_jus, it]) else 0

                # Verificação se é linha de costa. Se não for, cálculos seguem normais.
                for it in range(nt):
                    if self.d_qnat[idr, it] == 9999:
                        d_qmn[idr, it]      = 9999
                        self.d_qcat[idr, it]    = 9999
                        self.d_qmr[idr, it]     = 9999
                        self.d_qwit[idr, it]    = 9999
                        self.d_qnatres[idr, it] = 9999
                        self.d_qdef[idr, it]    = 9999
                        self.d_qdefacm[idr, it] = 9999
                        self.d_qout[idr, it]    = 9999
                        self.d_wbal[idr, it]    = 9999
                    else:                       
                            self.d_qcat[idr, it] = self.d_qnat[idr, it] - d_qmn[idr, it]
                            
                            # Cálculo do fluxo final para o trecho
                            # Caso for trecho com reservatório, os cálculos de qout e qnatres são modificados.
                            if self.Res[idr] == 1:
                                self.d_qout[idr, it] = self.d_qnat[idr, it]-self.d_qwit[idr,it]
                                self.d_qnatres[idr, it] = self.d_qnat[idr,it]
                            else:
                                if self.d_qcat [idr,it] < 0:
                                    self.d_qcat [idr,it] = 0
                                self.d_qout[idr, it] = self.d_qmr[idr, it] + self.d_qcat[idr, it] - self.d_qwit[idr, it] + self.d_qeff[idr,it]
                                self.d_qnatres[idr, it] = self.d_qmr[idr, it] + self.d_qcat[idr, it]


                            if self.d_qout[idr, it] < ((1-self.criterio)*self.d_qnatres[idr, it]):
                                self.d_qout[idr, it] = ((1-self.criterio)*self.d_qnatres[idr, it])
                                self.d_qdef[idr, it] = self.d_qwit[idr, it] - (self.criterio*self.d_qnatres[idr, it])
                            else:
                                self.d_qdef[idr, it] = 0
                            
                            # Déficit acumulado e Retirada Acumulada
                            self.d_qdefacm[idr, it] = d_qmd[idr, it] + self.d_qdef[idr, it]
                            if self.d_ord[idr] == 1:
                                self.d_qwitacm [idr,it] = self.d_qwit[idr,it]
                            else:
                                self.d_qwitacm [idr,it] = np.sum (self.d_qwitacm [ind_jus, it]) + self.d_qwit [idr,it]
                            
                            # Balanço hídrico final (em %)
                            if self.d_qnatres[idr, it] != 0:
                                self.d_wbal[idr, it] = (self.d_qwit[idr, it] / self.d_qnatres[idr, it]) * 100
                            else:
                                self.d_wbal[idr, it] = 0


            model_finish_time = datetime.datetime.now()
            logger.info(f"Final do processamento no run_model: {model_finish_time}")

            return self.d_qout, self.d_wbal, self.d_qdefacm, self.d_codbas, self.d_qcat, self.d_qmr, self.d_qwitacm, self.d_qdef
        except Exception as e:
            logger.error(f"Erro durante a execução do modelo: {e}")
            raise

def main():
    parser = argparse.ArgumentParser(description="Modelo de Balanço Hídrico – Versão Completa")
    parser.add_argument("--nt", type=int, default=1, help="Número de períodos")
    parser.add_argument("--scn_fix", type=str, choices=["str", "wit"], default="wit", help="Modo de cenário")
    parser.add_argument("--scn_number", type=int, default=1, help="Número do cenário fixo")
    parser.add_argument("--criterio", type=float, default=1, help="Cenário limitando retiradas de acordo com a vazão de outorga")
    args = parser.parse_args()

    start_time = datetime.datetime.now()
    logger.info(f"Início do processamento: {start_time}")

    # Definição dos caminhos dos arquivos de entrada
    drenagem_path = os.path.join(os.getcwd(), "entrada", "drenagem.csv")
    disponibilidade_path = os.path.join(os.getcwd(), "entrada", "disponibilidade.csv")
    retiradas_path = os.path.join(os.getcwd(), "entrada", "retiradas.csv")
    efluentes_path = os.path.join(os.getcwd(), "entrada", "efluentes.csv")

    # Criação e execução do modelo 
    model = BalanceModel(drenagem_file=drenagem_path,
                         disponibilidade_file=disponibilidade_path,
                         retiradas_file=retiradas_path,
                         efluentes_file=efluentes_path,
                         criterio = args.criterio,
                         nt=args.nt,
                         scn_fix=args.scn_fix,
                         scn_number=args.scn_number)
    
    try:
        model.read_data()
        model.associate_scenarios()
        d_qout, d_wbal, d_qdefacm, d_codbas, d_qcat, d_qmr, d_qwitacm, d_qdef = model.run_model()
        
        # Salva os resultados: adiciona os campos calculados ao DataFrame da rede de drenagem
        df_results = model.df_drl.copy()

        # Cria um dicionário com as novas colunas calculadas
        novas_colunas = {}
        novas_colunas[f"Q_Witacm"]    = d_qwitacm[:, 0]
        for i in range(1, args.nt + 1):
            novas_colunas[f"Q_Out_{i}"]   = d_qout[:, i - 1]
            novas_colunas[f"W_Bal_{i}"]   = d_wbal[:, i - 1]
            novas_colunas[f"QDefacm_{i}"] = d_qdefacm[:, i - 1]
            novas_colunas[f"Q_Cat_{i}"]    = d_qcat[:, i - 1]
            novas_colunas[f"Q_Mr_{i}"]     = d_qmr[:, i - 1]

            
        df_novas = pd.DataFrame(novas_colunas)
        df_results = pd.concat([df_results, df_novas], axis=1)

        # Remove as colunas dos cenários não fixados, conforme o tipo de cenário escolhido:
        if args.scn_fix == "wit":
            cols_to_drop = [col for col in df_results.columns 
                            if col.startswith("Q_wit_") and col != f"Q_wit_{args.scn_number}"]
            df_results = df_results.drop(columns=cols_to_drop)
        elif args.scn_fix == "str":
            cols_to_drop = [col for col in df_results.columns 
                            if col.startswith("Q_disp_") and col != f"Q_disp_{args.scn_number}"]
            df_results = df_results.drop(columns=cols_to_drop)

        output_file = os.path.join(os.getcwd(), "saida", "balanco_hidrico.csv")
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        df_results.to_csv(output_file, index=False)
        logger.info(f"Resultados salvos em: {output_file}")
    except Exception as e:
        logger.error(f"Erro durante a execução do script: {e}")

    finish_time = datetime.datetime.now()
    logger.info(f"Final do processamento: {finish_time}")

if __name__ == "__main__":
    main()
